{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random Forest Example\n",
    "\n",
    "Implement the Random Forest algorithm with TensorFlow and apply it to classify\n",
    "handwritten digit images. This example uses the MNIST database of\n",
    "handwritten digits as training and test samples (http://yann.lecun.com/exdb/mnist/).\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "\n",
    "import os\n",
    "\n",
    "# Hide every GPU from TensorFlow: the tf random forest does not benefit from\n",
    "# GPU execution. The variable is set before TensorFlow is imported so that it\n",
    "# is guaranteed to be in place before any CUDA device can be initialized.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.ops import resources\n",
    "from tensorflow.contrib.tensor_forest.python import tensor_forest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Load the MNIST dataset; the archives are downloaded to /tmp/data/.\n",
    "# one_hot=False keeps every label as a single integer class id.\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\n",
    "    \"/tmp/data/\",\n",
    "    one_hot=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training configuration\n",
    "num_steps = 500  # number of training iterations\n",
    "batch_size = 1024  # samples drawn per iteration\n",
    "\n",
    "# Dataset dimensions\n",
    "num_classes = 10  # one class per digit\n",
    "num_features = 28 * 28  # one feature per pixel of a 28x28 image\n",
    "\n",
    "# Forest size\n",
    "num_trees = 10\n",
    "max_nodes = 1000\n",
    "\n",
    "# Graph inputs: flattened images and their labels.\n",
    "# The random forest needs integer class ids as labels, hence the int32 dtype.\n",
    "X = tf.placeholder(dtype=tf.float32, shape=[None, num_features])\n",
    "Y = tf.placeholder(dtype=tf.int32, shape=[None])\n",
    "\n",
    "# Forest hyperparameters; fill() is called on the result before it is used\n",
    "# to build the forest.\n",
    "hparams = tensor_forest.ForestHParams(\n",
    "    num_classes=num_classes,\n",
    "    num_features=num_features,\n",
    "    num_trees=num_trees,\n",
    "    max_nodes=max_nodes\n",
    ").fill()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Constructing forest with params = \n",
      "INFO:tensorflow:{'valid_leaf_threshold': 1, 'split_after_samples': 250, 'num_output_columns': 11, 'feature_bagging_fraction': 1.0, 'split_initializations_per_input': 3, 'bagged_features': None, 'min_split_samples': 5, 'max_nodes': 1000, 'num_features': 784, 'num_trees': 10, 'num_splits_to_consider': 784, 'base_random_seed': 0, 'num_outputs': 1, 'dominate_fraction': 0.99, 'max_fertile_nodes': 500, 'bagged_num_features': 784, 'dominate_method': 'bootstrap', 'bagging_fraction': 1.0, 'regression': False, 'num_classes': 10}\n",
      "INFO:tensorflow:training graph for tree: 0\n",
      "INFO:tensorflow:training graph for tree: 1\n",
      "INFO:tensorflow:training graph for tree: 2\n",
      "INFO:tensorflow:training graph for tree: 3\n",
      "INFO:tensorflow:training graph for tree: 4\n",
      "INFO:tensorflow:training graph for tree: 5\n",
      "INFO:tensorflow:training graph for tree: 6\n",
      "INFO:tensorflow:training graph for tree: 7\n",
      "INFO:tensorflow:training graph for tree: 8\n",
      "INFO:tensorflow:training graph for tree: 9\n"
     ]
    }
   ],
   "source": [
    "# Assemble the random forest from the hyperparameters\n",
    "forest_graph = tensor_forest.RandomForestGraphs(hparams)\n",
    "\n",
    "# Ops that perform a training step and report the training loss\n",
    "train_op = forest_graph.training_graph(X, Y)\n",
    "loss_op = forest_graph.training_loss(X, Y)\n",
    "\n",
    "# Accuracy: share of samples whose highest-scoring class equals the label\n",
    "infer_op, _, _ = forest_graph.inference_graph(X)\n",
    "predicted_class = tf.argmax(infer_op, 1)\n",
    "true_class = tf.cast(Y, tf.int64)\n",
    "correct_prediction = tf.equal(predicted_class, true_class)\n",
    "accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# A single op that initializes the global variables and the shared resources\n",
    "init_variables = tf.global_variables_initializer()\n",
    "init_resources = resources.initialize_resources(resources.shared_resources())\n",
    "init_vars = tf.group(init_variables, init_resources)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1, Loss: -0.000000, Acc: 0.112305\n",
      "Step 50, Loss: -123.800003, Acc: 0.863281\n",
      "Step 100, Loss: -274.200012, Acc: 0.863281\n",
      "Step 150, Loss: -425.399994, Acc: 0.872070\n",
      "Step 200, Loss: -582.799988, Acc: 0.917969\n",
      "Step 250, Loss: -740.200012, Acc: 0.912109\n",
      "Step 300, Loss: -895.799988, Acc: 0.939453\n",
      "Step 350, Loss: -998.000000, Acc: 0.924805\n",
      "Step 400, Loss: -998.000000, Acc: 0.940430\n",
      "Step 450, Loss: -998.000000, Acc: 0.914062\n",
      "Step 500, Loss: -998.000000, Acc: 0.927734\n",
      "Test Accuracy: 0.9204\n"
     ]
    }
   ],
   "source": [
    "# Start TensorFlow session\n",
    "sess = tf.train.MonitoredSession()\n",
    "\n",
    "# Run the initializer\n",
    "sess.run(init_vars)\n",
    "\n",
    "# Training\n",
    "for i in range(1, num_steps + 1):\n",
    "    # Prepare Data\n",
    "    # Get the next batch of MNIST data (only images are needed, not labels)\n",
    "    batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
    "    _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})\n",
    "    if i % 50 == 0 or i == 1:\n",
    "        acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})\n",
    "        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))\n",
    "\n",
    "# Test Model\n",
    "test_x, test_y = mnist.test.images, mnist.test.labels\n",
    "print(\"Test Accuracy:\", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  },
  "varInspector": {
   "cols": {
    "lenName": 16.0,
    "lenType": 16.0,
    "lenVar": 40.0
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
